{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DOF and quadrature point plotter\n",
    "\n",
    "This notebook helps visualize the locations of MFEM's DOFs and quadrature points for various element types and quadrature rules.\n",
    "\n",
    "To use this, your python environment must have:\n",
    "* PyMFEM (https://github.com/mfem/PyMFEM)\n",
    "* Matplotlib (https://matplotlib.org)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preamble: import any necessary libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mfem.ser as mfem\n",
    "\n",
    "import numpy as np\n",
    "import os\n",
    "import math\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%config InlineBackend.figure_formats = ['svg']\n",
    "\n",
    "\n",
    "# Create some output directories for images\n",
    "for d in [\"figs\", \"figs/dofs\", \"figs/qpts\", \n",
    "          \"figs/dofs/svg\", \"figs/dofs/pdf\", \"figs/dofs/png\",\n",
    "          \"figs/qpts/svg\", \"figs/qpts/pdf\", \"figs/qpts/png\"]:\n",
    "    if not os.path.exists(d):\n",
    "        os.makedirs(d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup mesh and options\n",
    "\n",
    "Create an mfem mesh and define some parameters\n",
    "\n",
    "We currently support a simple Cartesian 2D mesh.\n",
    "This can be easily extended to load a mesh from file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Options for Cartesian mesh constructor are: \n",
    "# 0D: POINT\n",
    "# 1D: SEGMENT \n",
    "# 2D: TRIANGLE, QUADRILATERAL\n",
    "# 3D: TETRAHEDRON, HEXAHEDRON, WEDGE\n",
    "\n",
    "mesh = mfem.Mesh(1,1,\"QUADRILATERAL\")\n",
    "mesh.EnsureNodes()    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define some parameters for the generated figures\n",
    "\n",
    "dim = mesh.Dimension()\n",
    "max_order = 6\n",
    "\n",
    "# Polynomial order of the FECollection\n",
    "orders = [i for i in range(max_order)]\n",
    "\n",
    "# Basis types of the FiniteElementCollections\n",
    "b_types = [ mfem.BasisType.GaussLobatto,\n",
    "           mfem.BasisType.GaussLegendre,\n",
    "           mfem.BasisType.Positive ]\n",
    "\n",
    "# Quadrature types for the integration rules\n",
    "q_types = [{'q': mfem.Quadrature1D.GaussLegendre, 'name': 'Gauss-Legendre'},\n",
    "           {'q': mfem.Quadrature1D.GaussLobatto, 'name': 'Gauss-Lobatto'},\n",
    "           {'q': mfem.Quadrature1D.OpenUniform, 'name': 'Open-Uniform'},\n",
    "           {'q': mfem.Quadrature1D.OpenHalfUniform, 'name': 'Open-Half-Uniform'},\n",
    "           {'q': mfem.Quadrature1D.ClosedGL, 'name': 'Closed-Gauss-Legendre'},\n",
    "           {'q': mfem.Quadrature1D.ClosedUniform, 'name': 'Closed-Uniform'}]\n",
    "\n",
    "# Finite Element collection types\n",
    "fec_types = [ \"H1\", \"L2\" ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions to extract integration points and map them to physical space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getQptPositions(mesh, eid, ir):\n",
    "    \"\"\"Returns list of dictionaries of points in physical space for an element and integration rule.\n",
    "    \n",
    "       The dictionaries have entries for physical space position ('x', 'y')\n",
    "       and reference space position and weight ('ix', 'iy', 'w').\n",
    "    \"\"\"\n",
    "    \n",
    "    t = mesh.GetElementTransformation(eid)\n",
    "\n",
    "    pts = []\n",
    "    mfem.Vector(3)\n",
    "    for i in range(ir.GetNPoints()):\n",
    "        ip = ir.IntPoint(i)\n",
    "        v = t.Transform(ip)\n",
    "\n",
    "        d = {'x' : v[0], \n",
    "             'y' : v[1], \n",
    "             'ix': ip.x, \n",
    "             'iy': ip.y,\n",
    "             'w' : ip.weight}\n",
    "        #print(d)\n",
    "        pts.append(d)\n",
    "\n",
    "    return pts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getDofPositions(fespace, eid):\n",
    "    \"\"\"Returns a list of dictionaries containing the DOF positions \n",
    "    in physical and reference space\"\"\"\n",
    "    \n",
    "    mesh = fespace.GetMesh()\n",
    "    fe = fespace.GetFE(eid)\n",
    "    ir = fe.GetNodes()\n",
    "    \n",
    "    return getQptPositions(mesh, eid, ir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions to generate plots for DOFs and quadrature points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plotDofPositions(name, pts, wscale = .1):\n",
    "    \"\"\"Creates a matplotlib plot for the DOFs in pts\n",
    "    \n",
    "       Note: Currently assumes all points are in a unit square.\n",
    "       TODO: Extend this to draw curved elements of the mesh.\n",
    "    \"\"\"\n",
    "    \n",
    "\n",
    "    plt.clf()\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(6,6))\n",
    "    ax.set_xlim((-.1, 1.1))\n",
    "    ax.set_ylim((-.1, 1.1))\n",
    "\n",
    "    rect = plt.Rectangle((0, 0), 1, 1, linewidth=3, edgecolor='k', facecolor='none')\n",
    "    ax.add_patch(rect)\n",
    "\n",
    "    for p in pts:\n",
    "        circle=plt.Circle((p['x'],p['y']),  wscale, facecolor='#c0504d', edgecolor='#366092')\n",
    "        ax.add_patch(circle)\n",
    "\n",
    "    plt.axis('off')\n",
    "    ax.get_xaxis().set_visible(False)\n",
    "    ax.get_yaxis().set_visible(False)\n",
    "\n",
    "    print(F\"Saving DOFs for '{name}'\")\n",
    "    fig.savefig(F'figs/dofs/svg/{name}.svg') # bbox_inches=extent ?\n",
    "    fig.savefig(F'figs/dofs/png/{name}.png')\n",
    "    fig.savefig(F'figs/dofs/pdf/{name}.pdf')\n",
    "\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plotQptPositions(name, pts, wscale = .1, use_weights = True, min_size = None):\n",
    "    \"\"\"Creates a matplotlib plot for the quadrature points in pts\n",
    "    \n",
    "       Note: Currently assumes all points are in a unit square.\n",
    "       TODO: Extend this to draw curved elements of the mesh.\n",
    "       \n",
    "       Parameters:\n",
    "         name          The name for the output figures\n",
    "         pts           A collection of dictionaries of point data\n",
    "         wscale        Used to scale the quadrature weights\n",
    "         use_weights   When true, use the quadrature weights to scale the quadrature points\n",
    "         min_size      Optionally sets a lower bound on the size of the quadrature points\n",
    "    \"\"\"\n",
    "\n",
    "    plt.clf()\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(6,6))\n",
    "    ax.set_xlim((-.1, 1.1))\n",
    "    ax.set_ylim((-.1, 1.1))\n",
    "\n",
    "    rect = plt.Rectangle((0, 0), 1, 1, linewidth=1.5, edgecolor='#969696', facecolor='none')\n",
    "    ax.add_patch(rect)\n",
    "\n",
    "    for p in pts:\n",
    "        # Color depends on sign\n",
    "        facecolor = '#0871b7C0' if (p['w'] >= 0) else '#B74E08C0'\n",
    "        # Scale w/ weights proportional to area\n",
    "        sc = wscale * math.sqrt(abs(p['w'])) if use_weights else wscale\n",
    "        # Apply threshold to size, if applicable \n",
    "        sc = min_size if (min_size and sc < min_size) else sc\n",
    "        \n",
    "        circle=plt.Circle((p['x'],p['y']),  sc, facecolor=facecolor, edgecolor='#231f20')\n",
    "        ax.add_patch(circle)\n",
    "\n",
    "    plt.axis('off')\n",
    "    ax.get_xaxis().set_visible(False)\n",
    "    ax.get_yaxis().set_visible(False)\n",
    "\n",
    "    print(F\"Saving qpts for '{name}'\")\n",
    "    fig.savefig(F'figs/qpts/svg/{name}.svg') # bbox_inches=extent ?\n",
    "    fig.savefig(F'figs/qpts/png/{name}.png')\n",
    "    fig.savefig(F'figs/qpts/pdf/{name}.pdf')\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot the DOFs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Create a grid function for each FE type, order and basis type and plot figures\n",
    "# Skip the invalid combinations\n",
    "\n",
    "for f in fec_types:\n",
    "    for b in b_types:\n",
    "        for p in orders:\n",
    "            try:\n",
    "                if \"H1\" in f:\n",
    "                    fec = mfem.H1_FECollection(p, dim, b)\n",
    "                elif \"L2\" in f:\n",
    "                    fec = mfem.L2_FECollection(p, dim, b)\n",
    "                fespace = mfem.FiniteElementSpace(mesh, fec)      \n",
    "                \n",
    "                bname = mfem.BasisType.Name(b).split(\" \")[0]\n",
    "                print(F\"Working on {fec.Name()} -- dim {dim} -- basis type {bname} -- fec type {f}\" )\n",
    "\n",
    "                pts = getDofPositions(fespace, 0)\n",
    "                plotDofPositions(F'{fec.Name()}_{bname}', pts, 0.05)\n",
    "            except:\n",
    "                #print(F\"\\tFEC {fec.Name()} did not work: dim {dim} -- basis type {b} -- fec type {f}\" )\n",
    "                pass\n",
    "            "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot the quadrature points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Create a chart for each quadrature type and order defined above\n",
    "# Skip the invalid combinations\n",
    "# Note: Rules for 2*order and 2*order+1 are the same, so only plot the even ones\n",
    "\n",
    "# Currently hard-coded for squares\n",
    "g_type = mfem.Geometry.SQUARE\n",
    "g_name = \"square\"\n",
    "\n",
    "wscale=.125\n",
    "\n",
    "for q in q_types:\n",
    "    for o in orders:\n",
    "        try:\n",
    "            intrules = mfem.IntegrationRules(0, q['q'])\n",
    "            ir = intrules.Get(g_type, 2*o)\n",
    "            pts = getQptPositions(mesh, 0, ir)\n",
    "            pts = sorted(pts, key = lambda p: (p['x']-.5)**2 + (p['y']-.5)**2, reverse=True)\n",
    "                        \n",
    "            name = F\"qpts_{g_name}_{q['name']}_{dim}D_P{2*o}\"\n",
    "\n",
    "            #for p in pts:\n",
    "            #    print(F\"{name}: P{2*o} {p['w']}\")\n",
    "            \n",
    "            plotQptPositions(name, pts, wscale, True, min_size=None)\n",
    "            \n",
    "        except:\n",
    "            pass\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
